# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import argparse
import logging
import math
import os
import pathlib
import sys
from typing import Any, NamedTuple

import yaml
from decomp_interface_gen_op_list import (
    decomp_rule_interface_declare_gen_op_list,
    decomp_vjp_interface_declare_gen_op_list,
)
from gen_utils import attr_types_map, to_pascal_case
from infer_symbolic_shape_gen import gen_infer_symbolic_shape_str
from op_all_func_gen import gen_op_all_func
from op_build_gen import gen_build_func_str, gen_build_func_str_by_invoke
from op_interface_gen import (
    gen_exclusive_interface_str,
    gen_op_vjp_str,
)
from op_kerneltype_gen import gen_kernel_type_for_var_str
from op_verify_gen import gen_verify_func_str
from ops_onednn_extra_parser import parse_data_format_tensors, parse_extra_args
from parse_kernel_key_gen import gen_parse_kernel_key_str
from vjp_interface_black_list import vjp_interface_black_list

# Make <fluid>/primitive/codegen importable so that decomp_vjp_gen can be imported below.
sys.path.append(
    str(pathlib.Path(__file__).resolve().parents[3] / 'primitive/codegen')
)

import decomp_vjp_gen as vjp_gen

# Note(Galaxy1458): need_export_symbol_op_list lists the ops whose symbols must be
# exported, because some unit tests use these ops when compiled against the dynamic lib.
need_export_symbol_op_list = [
    'Add_Op',
    'AddNOp',
    'AbsOp',
    'FullOp',
    'UniformOp',
    'ScaleOp',
    'AddOp',
    'Conv2dOp',
    'BatchNormOp',
    'FetchOp',
    'FullIntArrayOp',
    'FusedConv2dAddActOp',
    'MatmulOp',
    'SoftmaxOp',
    'ReshapeOp',
    'TransposeOp',
    'LessThanOp',
    'LayerNormOp',
    'AddGradOp',
    'ConcatOp',
    'CummaxOp',
    'CastOp',
    'ReluOp',
    'ReluGradOp',
    'BatchNorm_Op',
    'GeluOp',
    'GeluGradOp',
    'MatmulGradOp',
]

cache_grad_op_shape_black_list = {"fused_attention"}

# =====================================
# String Template for h file code gen
# =====================================
NAMESPACE_GUARD_TEMPLATE = """namespace {namespace} {{
{input}
}} // namespace {namespace}"""

H_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/pir/dialect/op_generator/op_gen.py"
#pragma once
#include <vector>

#include "paddle/fluid/pir/dialect/operator/interface/decomp.h"
#include "paddle/fluid/pir/dialect/operator/interface/decomp_vjp.h"
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/infer_symbolic_shape.h"
#include "paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/cache_grad_op_symbolic_shape.h"
#include "paddle/fluid/pir/dialect/operator/interface/infermeta.h"
#include "paddle/fluid/pir/dialect/operator/interface/layout_transformation.h"
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
#include "paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h"
#include "paddle/fluid/pir/dialect/operator/interface/vjp.h"
#include "paddle/fluid/pir/dialect/operator/trait/inplace.h"
#include "paddle/fluid/pir/dialect/operator/trait/forward_only.h"
#include "paddle/fluid/pir/dialect/operator/utils/op_yaml_info_util.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
#include "paddle/pir/include/core/builder.h"
#include "paddle/pir/include/core/op_base.h"
#include "paddle/pir/include/core/op_trait.h"
#include "paddle/pir/include/core/operation_utils.h"
#ifdef PADDLE_WITH_DNNL
#include "paddle/fluid/pir/dialect/operator/trait/onednn.h"
#endif
#include "paddle/fluid/ir_adaptor/translator/pd_op_sig.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
#include "paddle/fluid/pir/dialect/operator/trait/custom_vjp.h"
#include "paddle/phi/core/infermeta_utils.h"
{only_pd_op_header_files}

{other_info}

{input}

{declare_type_id}
"""

OP_TO_MULTI_KERNELS_MAP_H = """
extern std::unordered_map<std::string, std::vector<PdOpSig>> op_to_multi_kernels_map;
extern std::unordered_map<std::string, std::vector<PdOpSig>> sp_op_to_multi_kernels_map;
"""

ONEDNN_ONLY_OP_SET_H = """
extern std::set<std::string> onednn_only_op_set;
"""

GET_OP_LIST_TEMPLATE = """{}
"""

DECLARE_OP_TYPE_ID = """
IR_DECLARE_EXPLICIT_TYPE_ID({op_name})
"""

OP_DECLARE_TEMPLATE = """
class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
 public:
  using Op::Op;
  static const char *name() {{ return "{dialect_op_name}"; }}
  {attribute_declare}
  static constexpr uint32_t attributes_num = {attribute_num};
  static OpInfoTuple GetOpInfo();
  static void Build({build_args});
  {build_mutable_attr_is_input}
  {build_attr_num_over_1}
  {build_mutable_attr_is_input_attr_num_over_1}
  void VerifySig();
{get_kernel_type_for_var_declare}
{parse_kernel_key_declare}
{infer_symbolic_shape_declare}
{cache_grad_op_symbolic_shape_declare}
{exclusive_interface}
}};
"""
op_0_attribute_declare_str = (
    "static constexpr const char **attributes_name = nullptr;"
)
op_n_attribute_declare_str = (
    "static const char *attributes_name[{attribute_num}];"
)

get_kernel_type_for_var_declare_template = """
  static phi::DataType GetKernelTypeForVar(
      const std::string& var_name,
        const phi::DataType& tensor_dtype,
        const phi::DataType& expected_kernel_dtype);
"""

parse_kernel_key_template = """
  static std::tuple<phi::DataType, phi::Backend> ParseKernelKey(pir::Operation *op);
"""

infer_symbolic_shape_template = """
  bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
"""

cache_grad_op_symbolic_shape_template = "  void CacheGradOpSymbolicShape(pir::InferSymbolicShapeContext* infer_context);"

# =====================================
# String Template for cc file code gen
# =====================================
CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/pir/dialect/op_generator/op_gen.py"
#include "{h_file}"
#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h"
#include "paddle/fluid/pir/dialect/operator/interface/layout_transformation.h"
#include "paddle/fluid/pir/dialect/operator/ir/ir_meta_tensor.h"
#include "paddle/fluid/pir/dialect/operator/ir/ir_selected_rows.h"
#include "paddle/fluid/pir/dialect/operator/ir/ir_sparse_tensor.h"
#include "paddle/fluid/pir/dialect/operator/ir/ir_tensor.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/infermeta/spmd_rules/rules.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_type.h"
#include "paddle/fluid/pir/dialect/distributed/ir/dist_tools.h"
#endif
#include "paddle/fluid/primitive/vjp_interface/vjp.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/infermeta/backward.h"
#include "paddle/phi/infermeta/binary.h"
#include "paddle/phi/infermeta/fusion.h"
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/sparse/backward.h"
#include "paddle/phi/infermeta/sparse/binary.h"
#include "paddle/phi/infermeta/sparse/multiary.h"
#include "paddle/phi/infermeta/sparse/unary.h"
#include "paddle/phi/infermeta/ternary.h"
#include "paddle/phi/infermeta/unary.h"
#include "paddle/pir/include/core/builtin_attribute.h"
#include "paddle/pir/include/core/builtin_op.h"
#include "paddle/pir/include/core/builtin_type.h"
#include "paddle/pir/include/core/ir_context.h"
#include "paddle/pir/include/core/op_base.h"

using namespace paddle::dialect;

{input}

{define_type_id}
"""
# NOTE(cocoshe): CC_OP_INFO_FILE_TEMPLATE_WIN_PART1 splits the op list across several
# macros (GET_OP_LIST1 .. GET_OP_LIST4) to avoid the
# "fatal error C1202: recursive type or function dependency context too complex" error
# raised when compiling on vs2017, because a single GET_OP_LIST is too long.
# CC_OP_INFO_FILE_TEMPLATE_PART1 generates just one GET_OP_LIST for other compilers.
CC_OP_INFO_FILE_TEMPLATE_PART1 = """#ifdef GET_OP_LIST
#undef GET_OP_LIST
{op_declare}
"""

CC_OP_INFO_FILE_TEMPLATE_WIN_PART1 = """#ifdef GET_OP_LIST1
#undef GET_OP_LIST1
{op_declare_first_part}
#elif defined(GET_OP_LIST2)
#undef GET_OP_LIST2
{op_declare_second_part}
#elif defined(GET_OP_LIST3)
#undef GET_OP_LIST3
{op_declare_third_part}
#elif defined(GET_OP_LIST4)
#undef GET_OP_LIST4
{op_declare_fourth_part}
"""

CC_OP_INFO_FILE_TEMPLATE_PART2 = """
#else
// This file is generated by "paddle/fluid/pir/dialect/op_generator/op_gen.py"
#include "{h_file}"

{other_info}
#endif
"""

# =====================================
# String Template for pd_op_vjp.cc file code gen
# =====================================
VJP_CC_FILE_TEMPLATE = """// This file is generated by "paddle/fluid/pir/dialect/op_generator/op_gen.py"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/dialect/operator/utils/utils.h"
#include "paddle/fluid/primitive/vjp_interface/vjp.h"
#include "paddle/fluid/primitive/base/lazy_tensor.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/pir/include/core/builtin_op.h"
#include "paddle/pir/include/core/op_base.h"

namespace paddle {{
namespace dialect {{
{input}
}}  // namespace dialect
}}  // namespace paddle
"""

OP_N_ATTRIBUTE_DEFINED_TEMPLATE = """
const char *{op_name}::attributes_name[{attribute_num}] = {{ {attribute_names} }};
"""

# get op info
OP_INFO_TEMPLATE = """
OpInfoTuple {op_name}::GetOpInfo() {{
  std::vector<paddle::dialect::OpInputInfo> inputs = {{ {inputs} }};
  std::vector<paddle::dialect::OpAttributeInfo> attributes = {{ {attributes} }};
  std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }};
  paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, "{kernel_func}", {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{kernel_key_backend}}}, {{{inplace}}}, {{{view}}});
  return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}");
}}
"""

OP_INFO_ONEDNN_TEMPLATE = """
OpInfoTuple {op_name}::GetOpInfo() {{
  std::vector<paddle::dialect::OpInputInfo> inputs = {{ {inputs} }};
  std::vector<paddle::dialect::OpAttributeInfo> attributes = {{ {attributes} }};
  std::vector<paddle::dialect::OpOutputInfo> outputs = {{ {outputs} }};
  pir::AttributeMap extra_attr_default_value;
  {extra_attr_default_value_code}
  paddle::dialect::OpRunTimeInfo run_time_info = paddle::dialect::OpRunTimeInfo("{infer_meta_func}", {{"{infer_meta_param}"}}, "{kernel_func}", {{"{kernel_param}"}}, {{{kernel_key_dtype}}}, {{{kernel_key_backend}}}, {{{inplace}}}, {{{view}}}, {{{extra_args}}}, {{{skip_transform_inputs}}}, extra_attr_default_value, {{{data_format_tensors}}}, {is_onednn_only}, {dynamic_fallback});
  return std::make_tuple(inputs, attributes, outputs, run_time_info, "{origin_op_name}");
}}
"""

CONSTRUCT_INPUT_INFO_TEMPLATE = """paddle::dialect::OpInputInfo("{name}", "{typename}", {optional}, {no_need_buffer}, {is_mutable_attribute}, {with_grad_semantic})"""
CONSTRUCT_OUTPUT_INFO_TEMPLATE = """paddle::dialect::OpOutputInfo("{name}", "{typename}", {optional}, {intermediate})"""
CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE = """paddle::dialect::OpAttributeInfo("{name}", "{typename}", "{data_type}")"""


DEFINE_OP_TYPE_ID = """
IR_DEFINE_EXPLICIT_TYPE_ID({op_name})
"""

OP_TO_MULTI_KERNELS_MAPS = """
std::unordered_map<std::string, std::vector<PdOpSig>> op_to_multi_kernels_map = {{
{maps}
}};
"""

SP_OP_TO_MULTI_KERNELS_MAPS = """
std::unordered_map<std::string, std::vector<PdOpSig>> sp_op_to_multi_kernels_map = {{
{maps}
}};
"""

ONEDNN_ONLY_OP_SET = """
std::set<std::string> onednn_only_op_set = {{
{maps}
}};
"""

# Maps a C++ scalar type name to the name of the pir attribute class of the
# same kind (e.g. 'int' -> 'pir::Int32Attribute').
scalar_type_maps = {
    'int': 'pir::Int32Attribute',
    'int64_t': 'pir::Int64Attribute',
    'float': 'pir::FloatAttribute',
    'double': 'pir::DoubleAttribute',
    'bool': 'pir::BoolAttribute',
}

# Set of op names (a set literal despite the _LIST suffix).
# NOTE(review): presumably these ops are written by hand (cf. the manual_op.h
# include in H_FILE_TEMPLATE) and treated specially by the generator — confirm
# against the usage further down the file.
PD_MANUAL_OP_LIST = {
    'add_n_',
    'split_grad',
    'expand',
    'increment',
    'increment_',
    'assign_out_',
}

# Set of op names (a set literal despite the _LIST suffix).
# NOTE(review): presumably the OneDNN counterpart of PD_MANUAL_OP_LIST —
# confirm against the usage further down the file.
ONEDNN_MANUAL_OP_LIST = {
    'split_grad',
    'expand',
}


class OpNamePair(NamedTuple):
    """The two names parsed from one ``phi_name (fluid_name)`` compat entry."""
    # Text before the optional parenthesised part.
    phi_name: str
    # Text inside the parentheses; equals phi_name when there are none.
    fluid_name: str


def to_phi_and_fluid_op_name(op_item: str) -> OpNamePair:
    """Split one compat op entry into its phi name and fluid name.

    The accepted form is ``phi_name (fluid_name)``. When the parenthesised
    part is missing, the single name is used for both fields. Surrounding
    whitespace is stripped from both names.
    """
    pieces = op_item.split('(')
    phi_name = pieces[0].strip()
    if len(pieces) == 1:
        # No "(fluid_name)" suffix: both names are the same.
        return OpNamePair(phi_name, phi_name)
    # Keep only what precedes the closing parenthesis.
    fluid_name = pieces[1].split(')')[0].strip()
    return OpNamePair(phi_name, fluid_name)


def to_phi_and_fluid_grad_op_name(op_item: str) -> list[OpNamePair]:
    """Parse a ``', '``-separated list of backward op entries.

    Example input: ``sum_grad (reduce_sum_grad), sum_double_grad``. Each
    entry is parsed with ``to_phi_and_fluid_op_name``.
    """
    return [to_phi_and_fluid_op_name(entry) for entry in op_item.split(', ')]


# =====================================
# Parse Op Compat From Yaml
# =====================================
class OpCompatParser:
    """Load an op-compat yaml file and answer lookups against its entries."""

    def __init__(self, ops_compat_yaml_file: str):
        """Read ``ops_compat_yaml_file`` and parse it with ``yaml.safe_load``.

        The parsed content is kept in ``self.ops_compat`` and is iterated as
        a sequence of compat entries by ``get_compat``.
        """
        self.ops_compat_yaml_file = ops_compat_yaml_file
        with open(self.ops_compat_yaml_file, "r") as f:
            self.ops_compat = yaml.safe_load(f)

    def get_compat(self, op_name: str):
        """Return the first compat entry that mentions ``op_name``.

        An entry matches when ``op_name`` equals the phi name of its 'op'
        field, or the phi name of any op listed in its 'backward' field.
        Returns None when no entry matches.
        """
        for compat in self.ops_compat:
            name_pair = to_phi_and_fluid_op_name(compat['op'])
            if op_name == name_pair.phi_name:
                return compat
            if 'backward' in compat:
                bkw_names = to_phi_and_fluid_grad_op_name(compat['backward'])
                if any(op_name == name.phi_name for name in bkw_names):
                    return compat
        return None

    def parse_support_tensor(
        self, op
    ) -> tuple[dict[str, dict[str, bool]], dict[str, dict[str, bool]]]:
        """Classify the op's 'support_tensor' attributes by attribute type.

        Args:
            op: Mapping with a 'support_tensor' list of attribute names and an
                'attrs' list of mappings carrying 'typename' and 'name'.

        Returns:
            ``(scalar_item, int_array_item)``: for every name listed in
            'support_tensor', ``{"support_tensor": True}`` is stored in
            ``scalar_item`` when the attribute's typename is 'Scalar', and in
            ``int_array_item`` when it is 'IntArray'. Other typenames are
            ignored.
        """
        scalar_item = {}
        int_array_item = {}
        for support_tensor_attr in op['support_tensor']:
            for attr in op['attrs']:
                typename = attr['typename']
                if (
                    typename not in ('Scalar', 'IntArray')
                    or attr['name'] != support_tensor_attr
                ):
                    continue
                # IntArray attributes belong in int_array_item; they used to
                # be written into scalar_item, leaving int_array_item empty.
                target = scalar_item if typename == 'Scalar' else int_array_item
                target[support_tensor_attr] = {"support_tensor": True}
        return scalar_item, int_array_item


# =====================================
# Parse Op Information From Yaml
# =====================================
class OpInfoParser:
    def __init__(self, op_yaml_item, op_compat_item, yaml_file):
        """Parse one op entry and cache every derived list/map as an attribute.

        Args:
            op_yaml_item: Mapping describing a single op; it is read with keys
                such as 'name', 'inputs', 'outputs', 'infer_meta', 'forward'.
            op_compat_item: Compat entry for the op, or None when there is
                none (parse_mutable_attribute checks it with ``is not None``).
            yaml_file: Path of the yaml file the op came from; only its suffix
                is inspected (see parse_op_type / parse_fused_op_type).

        The parse_* calls below are order dependent: several of them read
        attributes assigned by earlier calls (e.g. parse_input_type_dict reads
        ``self.kernel_map``, parse_mutable_attribute reads the attribute
        name/data-type lists, parse_non_mutable_attribute reads the mutable
        attribute names).
        """
        self.op_yaml_item = op_yaml_item
        self.op_compat_item = op_compat_item
        self.yaml_file = yaml_file
        self.is_sparse_op = self.parse_op_type()
        self.is_fused_op = self.parse_fused_op_type()
        self.op_phi_name = self.parse_op_phi_name()
        # Not computed here; initialised to None.
        self.class_name: str | None = None
        self.kernel_input_type_list: list[str] | None = None
        self.kernel_output_type_list: list[str] | None = None

        # Must be set before the input/output type dicts are parsed.
        self.kernel_map = self.parse_kernel_map()

        # parse inputs
        self.input_name_list = self.parse_input_name_list()
        self.input_type_list = self.parse_input_type_list()
        self.input_type_dict = self.parse_input_type_dict()
        self.input_optional_list = self.parse_input_optional_list()
        self.input_no_need_buffer_list = self.parse_input_no_need_buffer_list()
        self.cross_check(
            self.input_name_list, self.input_type_list, self.input_optional_list
        )

        # parse outputs
        self.output_name_list = self.parse_output_name_list()
        self.output_type_list = self.parse_output_type_list()
        self.output_type_dict = self.parse_output_type_dict()
        self.output_size_list = self.parse_output_size_list()
        self.output_optional_list = self.parse_output_optional_list()
        self.output_intermediate_list = self.parse_output_intermediate_list()
        self.cross_check(
            self.output_name_list,
            self.output_type_list,
            self.output_optional_list,
        )

        # parse attributes
        self.attr_types_map = attr_types_map
        self.attribute_name_list = self.parse_attribute_name_list()
        self.attribute_type_list = self.parse_attribute_type_list()
        self.attribute_build_arg_type_list = (
            self.parse_attribute_build_arg_type_list()
        )
        self.attribute_gen_arg_type_list = (
            self.parse_attribute_gen_arg_type_list()
        )
        self.attribute_data_type_list = self.parse_attribute_data_type_list()
        self.attribute_default_value_list = (
            self.parse_attribute_default_value_list()
        )
        self.cross_check(self.attribute_name_list, self.attribute_type_list)

        # parse mutable attributes (as inputs)
        (
            self.mutable_attribute_name_list,
            self.mutable_attribute_type_list,
        ) = self.parse_mutable_attribute()

        # The remaining attributes: every attribute whose name is not in
        # mutable_attribute_name_list.
        (
            self.non_mutable_attribute_name_list,
            self.non_mutable_attribute_type_list,
            self.non_mutable_attribute_data_type_list,
            self.non_mutable_attribute_build_arg_type_list,
            self.non_mutable_attribute_default_value_list,
        ) = self.parse_non_mutable_attribute()

        # parse infermeta && kernel
        self.infer_meta_map = self.parse_infer_meta_map()
        self.invoke_map = self.parse_invoke_map()
        self.spmd_rule_func = None
        if 'infer_meta' in self.op_yaml_item:
            self.infer_meta_func = self.op_yaml_item['infer_meta']["func"]
            # 'spmd_rule' is optional inside 'infer_meta'.
            if 'spmd_rule' in self.op_yaml_item['infer_meta']:
                self.spmd_rule_func = self.op_yaml_item['infer_meta'][
                    'spmd_rule'
                ]
        else:
            self.infer_meta_func = None

        if 'infer_shaped_type_op_interface' in self.op_yaml_item:
            self.infer_shaped_type_op_interface_func = self.op_yaml_item[
                'infer_shaped_type_op_interface'
            ]["func"]
        else:
            self.infer_shaped_type_op_interface_func = None

        # parse backward name
        self.backward_name = self.parse_backward_name()

        # parse inplace && view
        self.inplace_map = self.parse_op_inplace_info()
        # Sometimes an output is the inplace value of an input; if that input
        # is optional, the output should be optional too.
        if self.inplace_map is not None:
            self.refine_output_optional_list(self.output_optional_list)
            # Re-validate the output lists after the refinement.
            self.cross_check(
                self.output_name_list,
                self.output_type_list,
                self.output_optional_list,
            )
        self.view_map = self.parse_op_view_info()

        # parse data_transform
        self.data_transform_map = self.parse_data_transform_info()

        # parse has_custom_verify
        self.custom_verify = self.parse_custom_verify()

        # parse forward input name list and attribute name list
        self.forward_input_name_list = self.parse_forward_input_name()

        # parse forward output name list
        self.forward_output_name_list = self.parse_forward_output_name()

        # parse traits list
        self.traits_list = self.parse_op_traits()

        # parse interfaces list
        self.interfaces_list = self.parse_op_interfaces()

        # OneDNN info: each key is optional in the yaml item and gets the
        # default assigned in the matching else branch when absent.
        if "extra_args" in self.op_yaml_item:
            self.onednn_extra_args = self.op_yaml_item["extra_args"]
        else:
            self.onednn_extra_args = []

        if "data_format_tensors" in self.op_yaml_item:
            self.onednn_data_format_tensors = self.op_yaml_item[
                "data_format_tensors"
            ]
        else:
            self.onednn_data_format_tensors = None

        if "is_onednn_only" in self.op_yaml_item:
            self.is_onednn_only = self.op_yaml_item["is_onednn_only"]
        else:
            self.is_onednn_only = False

        if "dynamic_fallback" in self.op_yaml_item:
            self.dynamic_fallback = self.op_yaml_item["dynamic_fallback"]
        else:
            self.dynamic_fallback = False

    def parse_op_traits(self):
        """Return the op's 'traits' entry, or an empty list when it has none."""
        return self.op_yaml_item.get('traits', [])

    def parse_op_interfaces(self):
        """Return the op's 'interfaces' entry, or an empty list when it has none."""
        return self.op_yaml_item.get('interfaces', [])

    def parse_forward_input_name(self):
        """Return the input names of the op's 'forward' entry.

        Returns None when the yaml item has no 'forward' key or when its value
        is None.
        """
        forward_map = self.op_yaml_item.get('forward')
        if forward_map is None:
            return None
        return [entry['name'] for entry in forward_map['inputs']]

    def parse_forward_output_name(self):
        """Return the output names of the op's 'forward' entry.

        Returns None when the yaml item has no 'forward' key or when its value
        is None.
        """
        forward_map = self.op_yaml_item.get('forward')
        if forward_map is None:
            return None
        return [entry['name'] for entry in forward_map['outputs']]

    def cross_check(self, name_list, type_list, optional_list=None):
        """Assert that the given parallel lists have the same length.

        ``optional_list`` is compared against ``name_list`` only when it is
        not None. Raises AssertionError on a mismatch.
        """
        name_count = len(name_list)
        assert name_count == len(type_list), "name list size != type list size."
        if optional_list is None:
            return
        assert name_count == len(
            optional_list
        ), "type list size != optional list size."

    def parse_custom_verify(self):
        """Return the op's 'custom_verify' entry, or False when it is absent."""
        return self.op_yaml_item.get('custom_verify', False)

    def parse_op_phi_name(self):
        """Return the phi name(s) generated for this op.

        An op with neither inplace nor view info, or whose name already ends
        with "_", yields just its name. Otherwise both the name and the name
        with a trailing "_" are returned.
        """
        has_inplace_or_view = (
            self.parse_op_inplace_info() is not None
            or self.parse_op_view_info() is not None
        )
        op_name = self.op_yaml_item['name']
        if not has_inplace_or_view or op_name[-1] == "_":
            return [op_name]
        return [op_name, op_name + "_"]

    def parse_op_inplace_info(self):
        """Return the op's 'inplace' entry, or None when it is absent."""
        return self.op_yaml_item.get('inplace')

    def parse_op_view_info(self):
        """Return the op's 'view' entry, or None when it is absent."""
        return self.op_yaml_item.get('view')

    def is_mutable_attribute(self, attr_dict):
        """Tell whether a compat attribute entry marks a mutable attribute.

        True when 'support_tensor' is exactly ``True``, or when the entry has
        a 'tensor_name' or 'tensors_name' key; False otherwise.
        """
        if attr_dict.get('support_tensor') is True:
            return True
        return 'tensor_name' in attr_dict or 'tensors_name' in attr_dict

    def parse_mutable_attribute(self):
        """
        Collect the attributes flagged as mutable by the compat entry.

        Returns two parallel lists ordered like ``self.attribute_name_list``:
        the attribute names and ``[attribute_class, data_type]`` pairs, e.g.
        ['axis'] and [['paddle::dialect::ScalarAttribute', 'int']].
        """
        compat = self.op_compat_item
        found_names = []
        found_types = []

        # scalar (skipped for ops that come from sparse_ops.parsed.yaml)
        if (
            compat is not None
            and 'scalar' in compat
            and not self.yaml_file.endswith("sparse_ops.parsed.yaml")
        ):
            for scalar_attr, scalar_info in compat['scalar'].items():
                if not self.is_mutable_attribute(scalar_info):
                    continue
                if 'data_type' not in scalar_info:
                    # See eye in op_compat.yaml: no data_type in the compat
                    # entry, so reuse the one parsed for the attribute itself.
                    found_names.append(scalar_attr)
                    attr_index = self.attribute_name_list.index(scalar_attr)
                    found_types.append(
                        [
                            "paddle::dialect::ScalarAttribute",
                            self.attribute_data_type_list[attr_index],
                        ]
                    )
                    continue
                # one_hot's "depth" is recorded under the name "num_classes".
                is_one_hot_depth = (
                    scalar_attr == "depth" and self.op_phi_name[0] == "one_hot"
                )
                found_names.append(
                    "num_classes" if is_one_hot_depth else scalar_attr
                )
                data_type = scalar_info['data_type']
                # patch for isclose and allclose
                if compat['op'] in ("isclose", "allclose"):
                    data_type = "double"
                found_types.append(
                    ["paddle::dialect::ScalarAttribute", data_type]
                )

        # int_array
        if compat is not None and 'int_array' in compat:
            for int_array_attr, int_array_info in compat['int_array'].items():
                if not self.is_mutable_attribute(int_array_info):
                    continue
                found_names.append(int_array_attr)
                found_types.append(
                    [
                        "paddle::dialect::IntArrayAttribute",
                        int_array_info['data_type'],
                    ]
                )

        # Reorder to follow attribute_name_list; when a name was collected
        # more than once, the first collected type wins.
        type_by_name = {}
        for attr_name, attr_type in zip(found_names, found_types):
            type_by_name.setdefault(attr_name, attr_type)
        sorted_names = [
            attr_name
            for attr_name in self.attribute_name_list
            if attr_name in type_by_name
        ]
        sorted_types = [type_by_name[attr_name] for attr_name in sorted_names]
        return sorted_names, sorted_types

    def parse_non_mutable_attribute(self):
        """Return the per-attribute lists restricted to non-mutable attributes.

        An attribute is kept when its name is not in
        ``self.mutable_attribute_name_list``. The result is a 5-tuple of
        parallel lists: names, types, data types, build-arg types and default
        values, in the original attribute order.
        """
        kept_indices = [
            idx
            for idx, attr_name in enumerate(self.attribute_name_list)
            if attr_name not in self.mutable_attribute_name_list
        ]

        def pick(values):
            # Select the entries that sit at the kept positions.
            return [values[idx] for idx in kept_indices]

        return (
            pick(self.attribute_name_list),
            pick(self.attribute_type_list),
            pick(self.attribute_data_type_list),
            pick(self.attribute_build_arg_type_list),
            pick(self.attribute_default_value_list),
        )

    def parse_op_type(self) -> bool:
        """Return True when ``self.yaml_file`` is one of the sparse op yaml files."""
        sparse_suffixes = (
            "sparse_ops.parsed.yaml",
            "sparse_backward.parsed.yaml",
        )
        return self.yaml_file.endswith(sparse_suffixes)

    def parse_fused_op_type(self) -> bool:
        """Return True when ``self.yaml_file`` is one of the fused op yaml files."""
        fused_suffixes = (
            "fused_ops.parsed.yaml",
            "fused_backward.parsed.yaml",
        )
        return self.yaml_file.endswith(fused_suffixes)

    def parse_input_name_list(self):
        """Return the 'name' of every entry in the op's 'inputs', in order."""
        return [input_info['name'] for input_info in self.op_yaml_item['inputs']]

    def parse_input_type_list(self):
        """Return the C++ type name for every input, in declaration order.

        Only the yaml typenames 'Tensor' and 'Tensor[]' are accepted; any
        other typename raises AssertionError.
        """
        input_types_map = {
            'Tensor': 'paddle::dialect::DenseTensorType',
            'Tensor[]': 'pir::VectorType<paddle::dialect::DenseTensorType>',
        }
        type_list = []
        for input_info in self.op_yaml_item['inputs']:
            mapped_type = input_types_map.get(input_info['typename'])
            assert (
                mapped_type is not None
            ), f"{self.op_phi_name} : Input type error: the input type only support Tensor and Tensor[], but now is {input_info['typename']}."
            type_list.append(mapped_type)
        return type_list

    def parse_input_type_dict(self):
        """Return a dict mapping a kernel function name to its input C++ types.

        When there is no kernel map, the yaml input types are stored under the
        key 'default'. When the first kernel function has no dispatch info,
        they are stored under that function's name. Otherwise every kernel
        function gets the types listed in the first element of its dispatch
        entry ('dense', 'selected_rows', 'sparse_coo' or 'sparse_csr').
        """
        kernel_map = self.kernel_map
        if kernel_map is None:
            return {'default': self.parse_input_type_list()}
        if kernel_map['dispatch'][kernel_map['func'][0]] is None:
            # No per-kernel dispatch info: fall back to the yaml input types.
            yaml_type_list = self.parse_input_type_list()
            return {kernel_map['func'][0]: yaml_type_list}

        input_types_map = {
            'dense': 'paddle::dialect::DenseTensorType',
            'selected_rows': 'paddle::dialect::SelectedRowsType',
            'sparse_coo': 'paddle::dialect::SparseCooTensorType',
            'sparse_csr': 'paddle::dialect::SparseCsrTensorType',
        }
        type_dict = {}
        for kernel_func_name in kernel_map['func']:
            type_list = []
            # Element 0 of a dispatch entry lists the kernel's input kinds.
            for input_info in kernel_map['dispatch'][kernel_func_name][0]:
                assert (
                    input_info in input_types_map
                ), f"{self.op_phi_name} : Input type error: the input type only support dense and selected_rows, but now is {input_info}."
                type_list.append(input_types_map[input_info])
            type_dict[kernel_func_name] = type_list
        return type_dict

    def parse_input_optional_list(self):
        """Return "true"/"false" strings for the 'optional' flag of each input."""
        return [
            "true" if input_info['optional'] else "false"
            for input_info in self.op_yaml_item['inputs']
        ]

    def parse_input_no_need_buffer_list(self):
        """Return the "true"/"false" no-need-buffer flag of every input.

        The flag is "true" when the input entry's 'no_need_buffer' value is
        truthy and "false" otherwise; order follows the yaml input order.
        """
        return [
            "true" if entry['no_need_buffer'] else "false"
            for entry in self.op_yaml_item['inputs']
        ]

    def parse_output_name_list(self):
        """Return the 'name' of every output entry, in yaml order."""
        return [entry['name'] for entry in self.op_yaml_item['outputs']]

    def parse_output_type_list(self):
        """Translate each output's yaml 'typename' into its IR type string.

        Raises:
            AssertionError: if an output's typename has no entry in the
                translation table below.
        """
        ir_type_of = {
            'Tensor': 'paddle::dialect::DenseTensorType',
            'Tensor[]': 'pir::VectorType<paddle::dialect::DenseTensorType>',
            'SelectedRows': 'paddle::dialect::SelectedRowsType',
        }
        resolved = []
        for output_info in self.op_yaml_item['outputs']:
            typename = output_info['typename']
            assert (
                typename in ir_type_of
            ), f"{self.op_phi_name} : Output type error: the output type only support Tensor and Tensor[], but now is {output_info['typename']}."
            resolved.append(ir_type_of[typename])
        return resolved

    def parse_output_type_dict(self):
        """Map each kernel function name to the IR types of the op outputs.

        Two cases are handled:

        * ``self.kernel_map`` is None, or the dispatch entry of its first
          kernel function is None: the output types come from the yaml
          'typename' of each output. The single result key is 'default'
          when there is no kernel map, otherwise the first function name.
        * Otherwise: for every function in ``kernel_map['func']`` the output
          kinds are read from element ``[1]`` of its dispatch entry and
          translated one by one.

        Returns:
            dict mapping a kernel function name (or 'default') to a list of
            IR type strings.

        Raises:
            AssertionError: if a typename / dispatch kind is not supported.
        """
        kernel_map = self.kernel_map

        if (
            kernel_map is None
            or kernel_map['dispatch'][kernel_map['func'][0]] is None
        ):
            output_type_map = {
                'Tensor': 'paddle::dialect::DenseTensorType',
                'Tensor[]': 'pir::VectorType<paddle::dialect::DenseTensorType>',
                'SelectedRows': 'paddle::dialect::SelectedRowsType',
            }
            type_list = []
            for output_info in self.op_yaml_item['outputs']:
                assert (
                    output_info['typename'] in output_type_map
                ), f"{self.op_phi_name} : Output type error: the output type only support Tensor and Tensor[], but now is {output_info['typename']}."
                type_list.append(output_type_map[output_info['typename']])

            key = 'default' if kernel_map is None else kernel_map['func'][0]
            return {key: type_list}

        output_type_map = {
            'dense': 'paddle::dialect::DenseTensorType',
            'selected_rows': 'paddle::dialect::SelectedRowsType',
            'sparse_coo': 'paddle::dialect::SparseCooTensorType',
            'sparse_csr': 'paddle::dialect::SparseCsrTensorType',
        }
        type_dict = {}
        for kernel_func_name in kernel_map['func']:
            type_list = []
            # Element [1] of a dispatch entry holds the output kinds.
            for output_info in kernel_map['dispatch'][kernel_func_name][1]:
                # The message names outputs (it used to say "Input") and
                # lists every kind the table above accepts.
                assert (
                    output_info in output_type_map
                ), f"{self.op_phi_name} : Output type error: the output type only support dense, selected_rows, sparse_coo and sparse_csr, but now is {output_info}."
                type_list.append(output_type_map[output_info])
            type_dict[kernel_func_name] = type_list

        return type_dict

    def parse_output_size_list(self):
        """Return each output's 'size' value, or None where the key is absent."""
        return [
            entry['size'] if 'size' in entry else None
            for entry in self.op_yaml_item['outputs']
        ]

    def parse_output_optional_list(self):
        """Return a "true"/"false" optional flag for every output, in order.

        An output counts as optional when its 'optional' or 'intermediate'
        value is truthy. Either key may be missing; a missing key is treated
        as false. (Previously an entry carrying only one of the two keys
        raised KeyError, because the presence check used ``or`` while both
        keys were then subscripted.)
        """
        return [
            "true"
            if output_info.get('optional') or output_info.get('intermediate')
            else "false"
            for output_info in self.op_yaml_item['outputs']
        ]

    def refine_output_optional_list(self, optional_list):
        """Mark inplace outputs optional when their source input is optional.

        For every output whose name is a key of ``self.inplace_map``, the
        mapped input name is located in ``self.input_name_list``; if that
        input's flag in ``self.input_optional_list`` is "true", the matching
        slot of ``optional_list`` is set to "true". The list is modified in
        place and nothing is returned.
        """
        for position, output_info in enumerate(self.op_yaml_item['outputs']):
            output_name = output_info['name']
            if output_name not in self.inplace_map.keys():
                continue
            source_index = self.input_name_list.index(
                self.inplace_map[output_name]
            )
            if self.input_optional_list[source_index] == "true":
                optional_list[position] = "true"

    def parse_output_intermediate_list(self):
        """Return a "true"/"false" intermediate flag for every output.

        The flag is "true" only when the output entry has an 'intermediate'
        key whose value is truthy.
        """
        return [
            "true"
            if 'intermediate' in entry and entry['intermediate']
            else "false"
            for entry in self.op_yaml_item['outputs']
        ]

    def parse_attribute_name_list(self):
        """Return the 'name' of every attribute entry, in yaml order."""
        return [entry['name'] for entry in self.op_yaml_item['attrs']]

    def parse_attribute_build_arg_type_list(self):
        """Resolve the Build-argument type string of every attribute.

        The starting point is element ``[1]`` of the attribute's entry in
        ``self.attr_types_map``. Scalar-like types may be replaced by the
        attribute's own 'data_type' and then by a 'data_type' found under
        ``self.op_compat_item['scalar'][<attr name>]`` (skipped for the ops
        "isclose" and "allclose"). IntArray-like types with a 'data_type'
        become ``const <data_type>&``. The result is finally passed through
        ``self.get_phi_dtype_name``.

        Raises:
            AssertionError: if an attribute typename is not in
                ``self.attr_types_map``.
        """
        resolved_types = []
        for attribute_info in self.op_yaml_item['attrs']:
            assert (
                attribute_info['typename'] in self.attr_types_map
            ), f"{self.op_phi_name} : Attr type error."

            arg_type = self.attr_types_map[attribute_info['typename']][1]
            has_data_type = 'data_type' in attribute_info

            # Scalar & IntArray has data_type
            if 'Scalar' in arg_type:
                if has_data_type:
                    arg_type = attribute_info['data_type']
                op_name = self.op_yaml_item['name']
                attr_name = attribute_info['name']
                compat = self.op_compat_item
                if (
                    op_name not in ["isclose", "allclose"]
                    and compat is not None
                    and 'scalar' in compat
                ):
                    scalar_table = compat['scalar']
                    if (
                        attr_name in scalar_table
                        and 'data_type' in scalar_table[attr_name]
                    ):
                        arg_type = scalar_table[attr_name]['data_type']

            # Checked on the possibly replaced type, after the Scalar step.
            if 'IntArray' in arg_type and has_data_type:
                arg_type = "const " + attribute_info['data_type'] + "&"

            resolved_types.append(self.get_phi_dtype_name(arg_type))
        return resolved_types

    def parse_attribute_gen_arg_type_list(self):
        """Return element ``[1]`` of each attribute's type-map entry.

        Every value is passed through ``self.get_phi_dtype_name`` before
        being collected.

        Raises:
            AssertionError: if an attribute typename is not in
                ``self.attr_types_map``.
        """
        arg_types = []
        for attribute_info in self.op_yaml_item['attrs']:
            typename = attribute_info['typename']
            assert (
                typename in self.attr_types_map
            ), f"{self.op_phi_name} : Attr type error."
            arg_types.append(
                self.get_phi_dtype_name(self.attr_types_map[typename][1])
            )
        return arg_types

    def parse_attribute_type_list(self):
        """Return element ``[0]`` of each attribute's type-map entry.

        Raises:
            AssertionError: if an attribute typename is not in
                ``self.attr_types_map``.
        """
        attr_types = []
        for attribute_info in self.op_yaml_item['attrs']:
            typename = attribute_info['typename']
            assert (
                typename in self.attr_types_map
            ), f"{self.op_phi_name} : Attr type error."
            attr_types.append(self.attr_types_map[typename][0])
        return attr_types

    def parse_attribute_data_type_list(self):
        """Return each attribute's 'data_type', or "" where the key is absent."""
        return [
            entry['data_type'] if 'data_type' in entry else ""
            for entry in self.op_yaml_item['attrs']
        ]

    def parse_attribute_default_value_list(self):
        """Return each attribute's default value, or None where absent.

        A present 'default_value' is passed through
        ``self.get_phi_dtype_name`` before being collected.
        """
        return [
            self.get_phi_dtype_name(entry['default_value'])
            if 'default_value' in entry
            else None
            for entry in self.op_yaml_item['attrs']
        ]

    def parse_infer_meta_map(self):
        """Return the yaml item's 'infer_meta' entry, or None when absent."""
        yaml_item = self.op_yaml_item
        return yaml_item['infer_meta'] if 'infer_meta' in yaml_item else None

    def parse_kernel_map(self):
        """Return the yaml item's 'kernel' entry, or None when absent."""
        yaml_item = self.op_yaml_item
        return yaml_item['kernel'] if 'kernel' in yaml_item else None

    def parse_invoke_map(self):
        """Return the yaml item's 'invoke' entry, or None when absent."""
        yaml_item = self.op_yaml_item
        return yaml_item['invoke'] if 'invoke' in yaml_item else None

    def parse_data_transform_info(self):
        """Return the yaml item's 'data_transform' entry if it is truthy.

        A missing key and a falsy value both yield None.
        """
        data_trans_item = self.op_yaml_item.get('data_transform')
        return data_trans_item if data_trans_item else None

    def parse_backward_name(self):
        """Return the yaml item's 'backward' entry, or None when absent."""
        yaml_item = self.op_yaml_item
        return yaml_item['backward'] if 'backward' in yaml_item else None

    def get_phi_dtype_name(self, name: str):
        """Add the ``phi::`` qualifier to known type names inside ``name``.

        Every occurrence of 'Scalar', 'IntArray', 'DataLayout' and 'DataType'
        is prefixed with ``phi::``. In addition, a string that begins with one
        of the Place names gets ``phi::`` prepended as a whole.
        """
        for token in ('Scalar', 'IntArray', 'DataLayout', 'DataType'):
            name = name.replace(token, 'phi::' + token)

        place_prefixes = (
            "Place",
            "CPUPlace",
            "GPUPlace",
            "GPUPinnedPlace",
            "XPUPlace",
            "IPUPlace",
            "CustomPlace",
        )
        return "phi::" + name if name.startswith(place_prefixes) else name


def get_input_grad_semantic(
    op_info: OpInfoParser, op_info_items: dict[str, OpInfoParser]
):
    """Return a "true"/"false" grad-semantic flag for every forward input.

    An input is flagged "true" when the backward op has an output whose name,
    with its last five characters ("_grad") removed, equals the name at the
    same position of the backward op's ``forward_input_name_list``.

    All inputs are flagged "false" when the op has no backward name, when the
    backward name is not a key of ``op_info_items``, when the backward op's
    entry (looked up with an extra '_sp' suffix for sparse ops) is missing,
    or when the backward op has no ``forward_input_name_list``.

    Args:
        op_info: parsed forward op.
        op_info_items: op key -> parsed op for every known op.

    Returns:
        list of "true"/"false" strings, one per forward input.

    Raises:
        AssertionError: if the backward op's forward input count differs from
            the forward op's input count.
    """
    num_inputs = len(op_info.input_name_list)
    no_grad = ["false"] * num_inputs

    # get backward op
    bwd_op_name = op_info.backward_name
    if bwd_op_name is None or bwd_op_name not in op_info_items:
        return no_grad

    # The guard above tests the plain name while the entry is stored under a
    # '_sp'-suffixed key for sparse ops. The plain-name guard is kept so that
    # previously working inputs yield the same result; a missing suffixed
    # entry now means "no grad" instead of raising KeyError.
    sparse_op_name_suffix = '_sp' if op_info.is_sparse_op else ''
    bwd_op_info = op_info_items.get(bwd_op_name + sparse_op_name_suffix)
    if bwd_op_info is None:
        return no_grad

    bwd_fwd_input_list = bwd_op_info.forward_input_name_list
    if bwd_fwd_input_list is None:
        return no_grad

    assert (
        len(bwd_fwd_input_list) == num_inputs
    ), "Configuration of forward op and backward op is not match."

    # cut "_grad" of each output of bwd_op, then compare each modified output
    # with the corresponding input to decide whether it has grad semantic
    grad_targets = {
        bwd_output[:-5] for bwd_output in bwd_op_info.output_name_list
    }
    return [
        "true" if fwd_input in grad_targets else "false"
        for fwd_input in bwd_fwd_input_list
    ]


def get_mutable_attribute_grad_semantic(
    op_info: OpInfoParser, op_info_items: dict[str, OpInfoParser]
):
    """Return a "true"/"false" grad-semantic flag per mutable attribute.

    A mutable attribute is flagged "true" when the backward op has an output
    whose name, with its last five characters ("_grad") removed, equals the
    attribute name.

    All attributes are flagged "false" when the op has no backward name, when
    the backward name is not a key of ``op_info_items``, or when the backward
    op's entry (looked up with an extra '_sp' suffix for sparse ops) is
    missing.

    Args:
        op_info: parsed forward op.
        op_info_items: op key -> parsed op for every known op.

    Returns:
        list of "true"/"false" strings, one per mutable attribute.
    """
    fwd_mutable_attribute_list = op_info.mutable_attribute_name_list
    no_grad = ["false"] * len(fwd_mutable_attribute_list)

    # get backward op
    bwd_op_name = op_info.backward_name
    if bwd_op_name is None or bwd_op_name not in op_info_items:
        return no_grad

    # The guard above tests the plain name while the entry is stored under a
    # '_sp'-suffixed key for sparse ops. The plain-name guard is kept so that
    # previously working inputs yield the same result; a missing suffixed
    # entry now means "no grad" instead of raising KeyError.
    sparse_op_name_suffix = '_sp' if op_info.is_sparse_op else ''
    bwd_op_info = op_info_items.get(bwd_op_name + sparse_op_name_suffix)
    if bwd_op_info is None:
        return no_grad

    # cut "_grad" of each output of bwd_op, then compare each modified output
    # with the corresponding attribute to decide whether it has grad semantic
    grad_targets = {
        bwd_output[:-5] for bwd_output in bwd_op_info.output_name_list
    }
    return [
        "true" if attr_name in grad_targets else "false"
        for attr_name in fwd_mutable_attribute_list
    ]


def split_ops(op_info_items: dict[str, Any], cc_file: str, split_nums: int):
    """Partition ops into ``split_nums`` consecutive groups with file names.

    Ops keep their iteration order and are placed in chunks of
    ``ceil(len(op_info_items) / split_nums)`` entries, so trailing groups may
    be empty. The i-th output file name is ``cc_file`` with the 1-based index
    inserted in front of its last ".cc" (``op.cc`` -> ``op1.cc``); a name
    without ".cc" is returned unchanged.

    Args:
        op_info_items: op key -> op info.
        cc_file: path of the un-split source file.
        split_nums: number of groups / files to produce.

    Returns:
        ``(split_op_info_items, split_cc_files)``: a list of ``split_nums``
        dicts and a list of ``split_nums`` file names.
    """
    op_list = list(op_info_items.keys())
    ops_max_size = math.ceil(len(op_list) / split_nums)
    split_op_info_items = [{} for _ in range(split_nums)]
    for i, op_name in enumerate(op_list):
        split_op_info_items[i // ops_max_size][op_name] = op_info_items[
            op_name
        ]

    # Only the last ".cc" is rewritten: replacing every occurrence would also
    # mangle directory components that contain ".cc" (e.g. ".ccache").
    stem, ext, tail = cc_file.rpartition(".cc")
    split_cc_files = [
        f"{stem}{i + 1}{ext}{tail}" if ext else cc_file
        for i in range(split_nums)
    ]
    return split_op_info_items, split_cc_files


def GenOneDnnExtraAttrsDefaultValue(onednn_extra_args):
    """Build the C++ snippet that fills ``extra_attr_default_value``.

    For every extra argument the attribute type is element ``[0]`` of its
    ``attr_types_map`` entry. One ``pir::Attribute attr_<name> = ...``
    statement (or, for ``pir::ArrayAttribute<...>`` types, a block that builds
    a vector of element attributes) is emitted, followed by the assignment
    ``extra_attr_default_value["<name>"] = attr_<name>;``.

    Args:
        onednn_extra_args: sequence of dicts with 'typename' and 'name'
            keys; 'default_value' is read for array types and for types
            other than IntArray/Scalar.

    Returns:
        The concatenated C++ source text ("" for an empty sequence).

    Raises:
        AssertionError: if a typename is not in ``attr_types_map``.
    """
    INTARRAY_STR_TEMPLATE = """  pir::Attribute attr_{attr_name} = {op_attribute_type}::get(pir::IrContext::Instance(), phi::IntArray({attr}));
"""
    SCALAR_STR_TEMPLATE = """  pir::Attribute attr_{attr_name} = paddle::dialect::TransToIrAttribute({attr}, pir::IrContext::Instance());
"""
    STR_TEMPLATE = """  pir::Attribute attr_{attr_name} = {op_attribute_type}::get(pir::IrContext::Instance(), {attr});
"""
    ARRAY_ATTRIBUTE_TEMPLATE = """  std::vector<pir::Attribute> vec_{attr_name};
{{
    std::vector<{cpp_type}> vec_values = {attr_values};
    for (size_t i = 0; i < static_cast<size_t>(vec_values.size()); i++) {{
        {create_attribute}
        vec_{attr_name}.push_back(attr_{attr_name});
    }}
}}
pir::Attribute attr_{attr_name} = pir::ArrayAttribute::get(pir::IrContext::Instance(), vec_{attr_name});
"""
    # Attribute types with a dedicated creation template; every other type
    # uses STR_TEMPLATE. SCALAR_STR_TEMPLATE has no {op_attribute_type}
    # field, so the extra keyword passed to format() below is ignored for it.
    special_templates = {
        "paddle::dialect::IntArrayAttribute": INTARRAY_STR_TEMPLATE,
        "paddle::dialect::ScalarAttribute": SCALAR_STR_TEMPLATE,
    }
    array_attr_type = "pir::ArrayAttribute<"

    pieces = []
    for extra_arg in onednn_extra_args:
        assert (
            extra_arg['typename'] in attr_types_map
        ), f"{extra_arg['typename']} : Attr type error."
        extra_arg_type = attr_types_map[extra_arg['typename']][0]
        arg_name = extra_arg['name']

        if array_attr_type in extra_arg_type:
            # Strip the "pir::ArrayAttribute<" head and the closing ">".
            inner_attribute_type = extra_arg_type[len(array_attr_type) : -1]
            element_template = special_templates.get(
                inner_attribute_type, STR_TEMPLATE
            )
            cpp_type = extra_arg['typename'].replace('[]', '')
            attr_values = extra_arg['default_value']
            create_attribute = element_template.format(
                attr_name=arg_name,
                op_attribute_type=inner_attribute_type,
                attr="vec_values[i]",
            )
            pieces.append(
                ARRAY_ATTRIBUTE_TEMPLATE.format(
                    attr_name=arg_name,
                    cpp_type=cpp_type,
                    attr_values=attr_values,
                    create_attribute=create_attribute,
                )
            )
        elif extra_arg_type in special_templates:
            # NOTE(review): for plain IntArray/Scalar the emitted expression
            # is the attribute *name*, not its 'default_value' as in the
            # generic branch below - confirm this is intended.
            pieces.append(
                special_templates[extra_arg_type].format(
                    attr_name=arg_name,
                    op_attribute_type=extra_arg_type,
                    attr=arg_name,
                )
            )
        else:
            pieces.append(
                STR_TEMPLATE.format(
                    attr_name=arg_name,
                    op_attribute_type=extra_arg_type,
                    attr=extra_arg['default_value'],
                )
            )

        pieces.append(
            """extra_attr_default_value["{attr_name}"] = attr_{attr_name};\n""".format(
                attr_name=arg_name
            )
        )

    return "".join(pieces)


def AutoCodeGen(
    args: argparse.Namespace,
    op_info_items: dict[str, OpInfoParser],
    all_op_info_items: dict[str, OpInfoParser],
    namespaces: list[str],
    dialect_name: str,
):
    # (3) CodeGen: Traverse op_info_items and generate
    ops_name_list = []  # all op class name store in this list
    ops_declare_list = []  # all op class declare store in this list
    ops_defined_list = []  # all op class defined store in this list
    ops_vjp_defined_list = []  # all op vjp static interface definition

    # (4) parse name of ops which have custom vjp rules
    custom_vjp_op_name_list = []

    for custom_vjp in vjp_gen.CUSTOM_VJP:
        custom_vjp_op_name_list.append(custom_vjp[:-5])  # cut _grad

    op_to_multi_kernels_list = []
    sp_op_to_multi_kernels_list = []
    for key, op_info in op_info_items.items():
        if key == "add_n_grad":
            continue
        # get op inputs info
        op_input_name_list = op_info.input_name_list
        op_input_optional_list = op_info.input_optional_list
        op_input_no_need_buffer_list = op_info.input_no_need_buffer_list
        # get op outputs info
        op_output_name_list = op_info.output_name_list
        op_output_size_list = op_info.output_size_list
        op_output_optional_list = op_info.output_optional_list
        op_output_intermediate_list = op_info.output_intermediate_list
        # get op mutable attribute
        op_mutable_attribute_name_list = op_info.mutable_attribute_name_list
        op_mutable_attribute_type_list = op_info.mutable_attribute_type_list
        # get op attribute
        op_attribute_name_list = op_info.attribute_name_list
        op_attribute_type_list = op_info.attribute_type_list
        op_attribute_data_type_list = op_info.attribute_data_type_list
        op_attribute_build_arg_type_list = op_info.attribute_build_arg_type_list
        op_attribute_default_value_list = op_info.attribute_default_value_list
        op_non_mutable_attribute_name_list = (
            op_info.non_mutable_attribute_name_list
        )
        op_non_mutable_attribute_type_list = (
            op_info.non_mutable_attribute_type_list
        )
        op_non_mutable_attribute_data_type_list = (
            op_info.non_mutable_attribute_data_type_list
        )
        op_non_mutable_attribute_build_arg_type_list = (
            op_info.non_mutable_attribute_build_arg_type_list
        )
        op_non_mutable_attribute_default_value_list = (
            op_info.non_mutable_attribute_default_value_list
        )

        # others
        op_infer_meta_map = op_info.infer_meta_map
        op_kernel_map = op_info.kernel_map
        op_invoke_map = op_info.invoke_map
        op_inplace_map = op_info.inplace_map
        op_view_map = op_info.view_map
        op_data_transform_map = op_info.data_transform_map
        op_traits = op_info.traits_list
        op_interfaces = op_info.interfaces_list
        op_interfaces += ["paddle::dialect::OpYamlInfoInterface"]
        if (
            dialect_name == "pd_op"
            and op_info.backward_name
            and not op_info.is_sparse_op
            and all_op_info_items[op_info.backward_name].kernel_map is not None
            and op_info.op_phi_name[0] not in cache_grad_op_shape_black_list
        ):
            op_interfaces += [
                "paddle::dialect::CacheGradOpSymbolicShapeInterface"
            ]
        exclusive_interface_str = gen_exclusive_interface_str(
            op_info, op_info_items
        )

        if dialect_name == "pd_op" or dialect_name == "onednn_op":
            op_interfaces += ["paddle::dialect::GetKernelTypeForVarInterface"]

        # if op has custom vjp rule, then append a CustomVjpTrait to it
        if (
            op_info.op_phi_name[0] in custom_vjp_op_name_list
            and dialect_name != "onednn_op"
        ):
            op_traits += ["paddle::dialect::CustomVjpTrait"]

        # check op inputs and mutable_attributes grad semantics
        input_grad_semantics = get_input_grad_semantic(
            op_info, all_op_info_items
        )
        mutable_attribute_grad_semantics = get_mutable_attribute_grad_semantic(
            op_info, all_op_info_items
        )
        op_interfaces_tmp = op_interfaces
        exclusive_interface_str_tmp = exclusive_interface_str
        decomp_interface_str = "paddle::dialect::DecompInterface"
        decomp_interface_declare_str = "\n  static std::vector<std::vector<pir::Value>> Decomp(pir::Operation* op);"
        decomp_vjp_interface_str = "paddle::dialect::DecompVjpInterface"
        decomp_vjp_interface_declare_str = "\n  static std::vector<std::vector<pir::Value>> DecompVjp(pir::Operation* op);"

        # If op has inplace info, we will generate inplace op and non-inplace op.
        for op_name in op_info.op_phi_name:
            # =================================== #
            #        gen trait list str           #
            # =================================== #
            if op_name[-1] == "_":
                op_traits += ["paddle::dialect::InplaceTrait"]

            if dialect_name == "onednn_op":
                op_traits += ["paddle::dialect::OneDNNTrait"]

            if op_info.is_onednn_only:
                op_traits += ["paddle::dialect::OneDNNOnlyTrait"]

            if op_info.dynamic_fallback:
                op_traits += ["paddle::dialect::OneDNNDynamicFallbackTrait"]

            op_traits_str = ""
            if len(op_traits) > 0:
                op_traits_str = "," + ",".join(op_traits)

            if dialect_name == "onednn_op" and op_name in ONEDNN_MANUAL_OP_LIST:
                continue
            elif dialect_name != "onednn_op" and op_name in PD_MANUAL_OP_LIST:
                continue
            if op_kernel_map is None:
                func_list = [None]
            else:
                func_list = op_kernel_map['func']

            for kernel_func_name in func_list:
                if (
                    op_name in decomp_rule_interface_declare_gen_op_list
                    and kernel_func_name
                    in decomp_rule_interface_declare_gen_op_list
                    and dialect_name != "onednn_op"
                ):
                    if decomp_interface_str not in op_interfaces:
                        op_interfaces = [*op_interfaces, decomp_interface_str]
                    if (
                        decomp_interface_declare_str
                        not in exclusive_interface_str
                    ):
                        exclusive_interface_str += decomp_interface_declare_str
                elif (
                    op_name in decomp_vjp_interface_declare_gen_op_list
                    and kernel_func_name
                    in decomp_vjp_interface_declare_gen_op_list
                    and dialect_name != "onednn_op"
                ):
                    if decomp_vjp_interface_str not in op_interfaces:
                        op_interfaces = [
                            *op_interfaces,
                            decomp_vjp_interface_str,
                        ]
                    if (
                        decomp_vjp_interface_declare_str
                        not in exclusive_interface_str
                    ):
                        exclusive_interface_str += (
                            decomp_vjp_interface_declare_str
                        )
                else:
                    op_interfaces = op_interfaces_tmp
                    exclusive_interface_str = exclusive_interface_str_tmp

                # =================================== #
                #      gen interface list str         #
                # =================================== #
                op_class_name_suffix = 'Sp' if op_info.is_sparse_op else ''
                op_dialect_name_suffix = '_sp' if op_info.is_sparse_op else ''
                # kernel_func_name_inplace_suffix = (
                #     '_SpOp' if op_info.is_sparse_op else '_Op'
                # )
                op_dialect_name_inplace_suffix = (
                    'sp_' if op_info.is_sparse_op else ''
                )
                if len(func_list) == 1:
                    op_class_name = (
                        to_pascal_case(op_name) + op_class_name_suffix + "Op"
                    )
                    if op_name[-1] == "_":
                        op_dialect_name = (
                            dialect_name
                            + "."
                            + op_name
                            + op_dialect_name_inplace_suffix
                        )
                    else:
                        op_dialect_name = (
                            dialect_name
                            + "."
                            + op_name
                            + op_dialect_name_suffix
                        )
                else:
                    pascal_kernel_func_name = to_pascal_case(kernel_func_name)
                    if op_name[-1] == "_":
                        op_class_name = (
                            pascal_kernel_func_name
                            + op_class_name_suffix
                            + '_Op'
                        )
                        op_dialect_name = (
                            dialect_name
                            + "."
                            + kernel_func_name  # type: ignore
                            + "_"
                            + op_dialect_name_inplace_suffix
                        )
                    else:
                        op_class_name = (
                            pascal_kernel_func_name
                            + op_class_name_suffix
                            + "Op"
                        )
                        op_dialect_name = (
                            dialect_name
                            + "."
                            + kernel_func_name  # type: ignore
                            + op_dialect_name_suffix
                        )
                if kernel_func_name is None:
                    op_input_type_list = op_info.input_type_dict['default']
                    op_output_type_list = op_info.output_type_dict['default']
                else:
                    op_input_type_list = op_info.input_type_dict[
                        kernel_func_name
                    ]
                    op_output_type_list = op_info.output_type_dict[
                        kernel_func_name
                    ]

                op_info.class_name = op_class_name
                op_info.kernel_input_type_list = op_input_type_list

                op_info.kernel_output_type_list = op_output_type_list

                (
                    all_interface_list,
                    exclusive_declare_list,
                    exclusive_impl_list,
                ) = gen_op_all_func(args, op_info, op_info_items)
                all_interface_list += op_interfaces

                all_interface_str = ""
                if len(all_interface_list) > 0:
                    all_interface_str = "," + ",".join(all_interface_list)

                all_declare_str = (
                    exclusive_interface_str
                    + '\n'
                    + '\n'.join(exclusive_declare_list)
                )
                ops_defined_list += exclusive_impl_list

                # =================================== #
                #         gen Build methods str       #
                # =================================== #
                build_args_with_muta_attr_not_input_for_declare = ""
                build_func_with_muta_attr_not_input = ""
                build_mutable_attr_is_input = ""
                build_func_with_muta_attr_is_input_with_attr_is_map = ""
                build_attr_num_over_1 = ""
                build_mutable_attr_is_input_attr_num_over_1 = ""
                build_func_with_attr_is_map = ""
                build_func_with_muta_attr_is_input = ""

                get_kernel_type_for_var_declare_str = ""
                if dialect_name == "pd_op" or dialect_name == "onednn_op":
                    get_kernel_type_for_var_declare_str = (
                        get_kernel_type_for_var_declare_template
                    )

                parse_kernel_key_str = ""
                if (
                    "paddle::dialect::ParseKernelKeyInterface"
                    in all_interface_list
                ):
                    parse_kernel_key_str = parse_kernel_key_template

                infer_symbolic_shape_str = ""
                if (
                    "paddle::dialect::InferSymbolicShapeInterface"
                    in all_interface_list
                ):
                    infer_symbolic_shape_str = infer_symbolic_shape_template

                cache_grad_op_symbolic_shape_str = ""
                if op_info.backward_name:
                    cache_grad_op_symbolic_shape_str = (
                        cache_grad_op_symbolic_shape_template
                    )

                if op_infer_meta_map is not None:
                    (
                        build_args_with_muta_attr_not_input_for_declare,
                        build_func_with_muta_attr_not_input,
                    ) = gen_build_func_str(
                        args,
                        op_info,
                        op_attribute_name_list,
                        op_attribute_type_list,
                        op_attribute_build_arg_type_list,
                        op_attribute_default_value_list,
                        op_mutable_attribute_name_list,
                        op_mutable_attribute_type_list,
                        op_non_mutable_attribute_name_list,
                        op_non_mutable_attribute_type_list,
                        op_non_mutable_attribute_build_arg_type_list,
                        op_non_mutable_attribute_default_value_list,
                        muta_attr_is_input=False,
                    )
                    if len(op_attribute_name_list) > 0:
                        (
                            build_args_with_attr_is_map_for_declare,
                            build_func_with_attr_is_map,
                        ) = gen_build_func_str(
                            args,
                            op_info,
                            op_attribute_name_list,
                            op_attribute_type_list,
                            op_attribute_build_arg_type_list,
                            op_attribute_default_value_list,
                            op_mutable_attribute_name_list,
                            op_mutable_attribute_type_list,
                            op_non_mutable_attribute_name_list,
                            op_non_mutable_attribute_type_list,
                            op_non_mutable_attribute_build_arg_type_list,
                            op_non_mutable_attribute_default_value_list,
                            muta_attr_is_input=False,
                            attr_args_is_map=True,
                        )
                        build_attr_num_over_1 = f"static void Build({build_args_with_attr_is_map_for_declare});"
                    if len(op_mutable_attribute_name_list) > 0:
                        (
                            build_args_with_muta_attr_is_input_for_declare,
                            build_func_with_muta_attr_is_input,
                        ) = gen_build_func_str(
                            args,
                            op_info,
                            op_attribute_name_list,
                            op_attribute_type_list,
                            op_attribute_build_arg_type_list,
                            op_attribute_default_value_list,
                            op_mutable_attribute_name_list,
                            op_mutable_attribute_type_list,
                            op_non_mutable_attribute_name_list,
                            op_non_mutable_attribute_type_list,
                            op_non_mutable_attribute_build_arg_type_list,
                            op_non_mutable_attribute_default_value_list,
                            muta_attr_is_input=True,
                        )

                        build_mutable_attr_is_input = f"static void Build({build_args_with_muta_attr_is_input_for_declare});"
                # TODO(huangjiyi): support invoke op for sparse op.
                if (
                    (op_invoke_map is not None)
                    and (not op_info.is_sparse_op)
                    and (op_invoke_map['func'] in all_op_info_items)
                ):
                    op_invoke_class_name = (
                        to_pascal_case(op_invoke_map['func']) + "Op"
                    )

                    (
                        build_args_with_muta_attr_not_input_for_declare,
                        build_func_with_muta_attr_not_input,
                    ) = gen_build_func_str_by_invoke(
                        op_class_name,
                        op_input_name_list,
                        op_input_type_list,
                        op_input_optional_list,
                        op_attribute_name_list,
                        op_attribute_type_list,
                        op_attribute_build_arg_type_list,
                        op_attribute_default_value_list,
                        op_mutable_attribute_name_list,
                        op_mutable_attribute_type_list,
                        op_non_mutable_attribute_name_list,
                        op_non_mutable_attribute_type_list,
                        op_non_mutable_attribute_build_arg_type_list,
                        op_non_mutable_attribute_default_value_list,
                        op_invoke_class_name,
                        op_invoke_map,
                    )
                # gen op_declare_str/op_defined_str
                TEST_API = ""
                if op_class_name in need_export_symbol_op_list:
                    TEST_API = "TEST_API"
                if len(op_non_mutable_attribute_name_list) == 0:
                    op_declare_str = OP_DECLARE_TEMPLATE.format(
                        TEST_API=TEST_API,
                        op_name=op_class_name,
                        dialect_op_name=op_dialect_name,
                        interfaces=all_interface_str,
                        traits=op_traits_str,
                        attribute_declare=op_0_attribute_declare_str,
                        attribute_num=0,
                        build_args=build_args_with_muta_attr_not_input_for_declare,
                        build_mutable_attr_is_input=build_mutable_attr_is_input,
                        build_attr_num_over_1=build_attr_num_over_1,
                        build_mutable_attr_is_input_attr_num_over_1=build_mutable_attr_is_input_attr_num_over_1,
                        exclusive_interface=all_declare_str,
                        get_kernel_type_for_var_declare=get_kernel_type_for_var_declare_str,
                        parse_kernel_key_declare=parse_kernel_key_str,
                        infer_symbolic_shape_declare=infer_symbolic_shape_str,
                        cache_grad_op_symbolic_shape_declare=cache_grad_op_symbolic_shape_str,
                    )
                    op_defined_str = ""
                else:
                    op_declare_str = OP_DECLARE_TEMPLATE.format(
                        TEST_API=TEST_API,
                        op_name=op_class_name,
                        dialect_op_name=op_dialect_name,
                        interfaces=all_interface_str,
                        traits=op_traits_str,
                        attribute_declare=op_n_attribute_declare_str.format(
                            attribute_num=len(
                                op_non_mutable_attribute_name_list
                            )
                        ),
                        attribute_num=len(op_non_mutable_attribute_name_list),
                        build_args=build_args_with_muta_attr_not_input_for_declare,
                        build_mutable_attr_is_input=build_mutable_attr_is_input,
                        build_attr_num_over_1=build_attr_num_over_1,
                        build_mutable_attr_is_input_attr_num_over_1=build_mutable_attr_is_input_attr_num_over_1,
                        exclusive_interface=all_declare_str,
                        get_kernel_type_for_var_declare=get_kernel_type_for_var_declare_str,
                        parse_kernel_key_declare=parse_kernel_key_str,
                        infer_symbolic_shape_declare=infer_symbolic_shape_str,
                        cache_grad_op_symbolic_shape_declare=cache_grad_op_symbolic_shape_str,
                    )
                    attribute_names_str = (
                        '"'
                        + '", "'.join(op_non_mutable_attribute_name_list)
                        + '"'
                    )
                    op_defined_str = OP_N_ATTRIBUTE_DEFINED_TEMPLATE.format(
                        op_name=op_class_name,
                        attribute_num=len(op_non_mutable_attribute_name_list),
                        attribute_names=attribute_names_str,
                    )

                # =================================== #
                #         gen GetOpInfo func str      #
                # =================================== #
                # generate get op info function: inputs
                input_info_list = []
                for idx in range(len(op_input_name_list)):
                    input_info_list.append(
                        CONSTRUCT_INPUT_INFO_TEMPLATE.format(
                            name=op_input_name_list[idx],
                            typename=op_input_type_list[idx],
                            optional=op_input_optional_list[idx],
                            no_need_buffer=op_input_no_need_buffer_list[idx],
                            is_mutable_attribute='false',
                            with_grad_semantic=input_grad_semantics[idx],
                        )
                    )
                for idx in range(len(op_mutable_attribute_name_list)):
                    input_info_list.append(
                        CONSTRUCT_INPUT_INFO_TEMPLATE.format(
                            name=op_mutable_attribute_name_list[idx],
                            typename=op_mutable_attribute_type_list[idx][0],
                            optional='false',
                            no_need_buffer='false',
                            is_mutable_attribute='true',
                            with_grad_semantic=mutable_attribute_grad_semantics[
                                idx
                            ],
                        )
                    )
                if len(input_info_list) > 0:
                    inputs_info_str = ", ".join(input_info_list)
                else:
                    inputs_info_str = ""
                # generate get op info function: outputs
                outputs_info_str = ""
                if len(op_output_name_list) > 0:
                    output_info_list = []
                    for idx in range(len(op_output_name_list)):
                        output_info_list.append(
                            CONSTRUCT_OUTPUT_INFO_TEMPLATE.format(
                                name=op_output_name_list[idx],
                                typename=op_output_type_list[idx],
                                optional=op_output_optional_list[idx],
                                intermediate=op_output_intermediate_list[idx],
                            )
                        )
                    outputs_info_str = ", ".join(output_info_list)
                # generate get op info function: attributes
                attribute_info_str = ""
                if len(op_non_mutable_attribute_name_list) > 0:
                    attribute_info_list = []
                    for idx in range(len(op_non_mutable_attribute_name_list)):
                        attribute_info_list.append(
                            CONSTRUCT_ATTRIBUTE_INFO_TEMPLATE.format(
                                name=op_non_mutable_attribute_name_list[idx],
                                typename=op_non_mutable_attribute_type_list[
                                    idx
                                ],
                                data_type=op_non_mutable_attribute_data_type_list[
                                    idx
                                ],
                            )
                        )
                    attribute_info_str = ", ".join(attribute_info_list)
                # generate runtime info
                infer_meta_func_str = ""
                infer_meta_param_str = ""
                if op_infer_meta_map is not None:
                    infer_meta_func_str = op_infer_meta_map['func']
                    infer_meta_param_str = '", "'.join(
                        str(item) for item in op_infer_meta_map['param']
                    )

                kernel_func_str = ""
                kernel_param_str = ""
                kernel_key_dtype = ""
                kernel_key_backend = ""
                if op_kernel_map is not None:
                    kernel_func_str = kernel_func_name
                    kernel_param_str = '", "'.join(op_kernel_map['param'])
                    if op_kernel_map.get('data_type'):
                        for idx in range(
                            len(op_kernel_map['data_type']['candidates'])
                        ):
                            if (
                                'to_complex_flag' in op_kernel_map['data_type']
                                and op_kernel_map['data_type'][
                                    'to_complex_flag'
                                ][idx]
                            ):
                                kernel_key_dtype += (
                                    'complex:'
                                    + op_kernel_map['data_type']['candidates'][
                                        idx
                                    ]
                                    + '", "'
                                )
                            else:
                                kernel_key_dtype += (
                                    op_kernel_map['data_type']['candidates'][
                                        idx
                                    ]
                                    + '", "'
                                )
                        if kernel_key_dtype != "":
                            kernel_key_dtype = '"' + kernel_key_dtype[:-3]
                    if op_kernel_map.get('backend'):
                        kernel_key_backend = '", "'.join(
                            op_kernel_map['backend']['candidates']
                        )
                        if kernel_key_backend != "":
                            kernel_key_backend = '"' + kernel_key_backend + '"'

                inplace_str = ""
                view_str = ""
                if op_name[-1] == "_":
                    if op_inplace_map is not None:
                        for key, value in op_inplace_map.items():
                            inplace_str += '{"' + key + '", "' + value + '"},'
                        inplace_str = inplace_str[:-1]
                    if op_view_map is not None:
                        for key, value in op_view_map.items():
                            view_str += '{"' + key + '", "' + value + '"},'
                        view_str = view_str[:-1]

                op_info_func_str = OP_INFO_TEMPLATE.format(
                    op_name=op_class_name,
                    inputs=inputs_info_str,
                    attributes=attribute_info_str,
                    outputs=outputs_info_str,
                    infer_meta_func=infer_meta_func_str,
                    infer_meta_param=infer_meta_param_str,
                    kernel_func=kernel_func_str,
                    kernel_param=kernel_param_str,
                    kernel_key_dtype=kernel_key_dtype,
                    kernel_key_backend=kernel_key_backend,
                    inplace=inplace_str,
                    view=view_str,
                    origin_op_name=op_info.op_yaml_item['name'],
                )

                if dialect_name == "onednn_op":
                    if (
                        op_info.onednn_extra_args is not None
                        and len(op_info.onednn_extra_args) > 0
                    ):
                        args_name = []
                        for arg in op_info.onednn_extra_args:
                            args_name.append(arg["name"])

                        extra_args = '"' + '", "'.join(args_name) + '"'
                    else:
                        extra_args = ""
                    if op_info.onednn_data_format_tensors is None:
                        data_format_tensors = ""
                    else:
                        data_format_tensors = op_info.onednn_data_format_tensors
                        data_format_tensors = (
                            '"' + '", "'.join(data_format_tensors) + '"'
                        )
                    if (
                        op_info.onednn_extra_args is not None
                        and len(op_info.onednn_extra_args) > 0
                    ):
                        extra_attr_default_value_code_str = (
                            GenOneDnnExtraAttrsDefaultValue(
                                op_info.onednn_extra_args
                            )
                        )
                    else:
                        extra_attr_default_value_code_str = ""
                    skip_transform_inputs = ""
                    if op_info.data_transform_map is not None:
                        if "skip_transform" in op_info.data_transform_map:
                            skip_transform = op_info.data_transform_map[
                                "skip_transform"
                            ]
                            if skip_transform is not None:
                                skip_transform_input_names = []
                                for input in skip_transform:
                                    skip_transform_input_names.append(input)

                                skip_transform_inputs = (
                                    '"'
                                    + '", "'.join(skip_transform_input_names)
                                    + '"'
                                )

                    op_info_func_str = OP_INFO_ONEDNN_TEMPLATE.format(
                        op_name=op_class_name,
                        inputs=inputs_info_str,
                        attributes=attribute_info_str,
                        outputs=outputs_info_str,
                        extra_attr_default_value_code=extra_attr_default_value_code_str,
                        infer_meta_func=infer_meta_func_str,
                        infer_meta_param=infer_meta_param_str,
                        kernel_func=kernel_func_str,
                        kernel_param=kernel_param_str,
                        kernel_key_dtype=kernel_key_dtype,
                        kernel_key_backend=kernel_key_backend,
                        inplace=inplace_str,
                        view=view_str,
                        origin_op_name=op_info.op_yaml_item['name'],
                        extra_args=extra_args,
                        skip_transform_inputs=skip_transform_inputs,
                        data_format_tensors=data_format_tensors,
                        is_onednn_only=(
                            "true" if op_info.is_onednn_only else "false"
                        ),
                        dynamic_fallback=(
                            "true" if op_info.dynamic_fallback else "false"
                        ),
                    )
                # generate op verify function str
                op_verify_str = ''
                if not op_info.custom_verify:
                    op_verify_str = gen_verify_func_str(
                        op_class_name,
                        op_input_type_list,
                        op_input_optional_list,
                        op_mutable_attribute_name_list,
                        op_mutable_attribute_type_list,
                        op_non_mutable_attribute_name_list,
                        op_non_mutable_attribute_type_list,
                        op_output_type_list,
                        op_output_optional_list,
                    )

                # generate op ParseKernelKeyInterface function str
                parse_kernel_key_define_str = ''
                if (
                    "paddle::dialect::ParseKernelKeyInterface"
                    in all_interface_list
                ):
                    parse_kernel_key_define_str = gen_parse_kernel_key_str(
                        op_class_name
                    )

                # generate op InferSymbolicShapeInterface function str
                infer_symbolic_shape_define_str = ''
                if (
                    "paddle::dialect::InferSymbolicShapeInterface"
                    in all_interface_list
                ):
                    infer_symbolic_shape_define_str = (
                        gen_infer_symbolic_shape_str(op_class_name)
                    )

                # generate op GetKernelKeyForVar function str
                op_get_kernel_type_for_var_str = ''
                if dialect_name == "pd_op" or dialect_name == "onednn_op":
                    op_get_kernel_type_for_var_str = (
                        gen_kernel_type_for_var_str(
                            op_class_name,
                            op_data_transform_map,
                            op_kernel_map,
                            op_info.op_compat_item,
                        )
                    )

                # =================================== #
                #         gen Vjp func str      #
                # =================================== #

                # generate op vjp function str
                op_vjp_str = ''
                if dialect_name == "cinn":
                    logging.warning(
                        "cinn is currently not support Vjp function"
                    )
                else:
                    if (
                        op_info.backward_name
                        and op_info.op_phi_name[0]
                        not in vjp_interface_black_list
                        and dialect_name != "onednn_op"
                    ):
                        sparse_op_name_suffix = (
                            '_sp' if op_info.is_sparse_op else ''
                        )
                        op_vjp_str = gen_op_vjp_str(
                            op_class_name,
                            op_info.backward_name,
                            op_name,
                            op_info_items[
                                op_info.op_phi_name[0] + sparse_op_name_suffix
                            ],
                            all_op_info_items[
                                op_info.backward_name + sparse_op_name_suffix
                            ],
                        )

                    ops_name_list.append(op_class_name)
                    ops_declare_list.append(op_declare_str)
                    ops_defined_list.append(op_defined_str)
                    ops_defined_list.append(op_info_func_str)
                    ops_defined_list.append(build_func_with_muta_attr_not_input)
                    ops_defined_list.append(build_func_with_attr_is_map)
                    if len(op_mutable_attribute_name_list) > 0:
                        ops_defined_list.append(
                            build_func_with_muta_attr_is_input
                        )
                        ops_defined_list.append(
                            build_func_with_muta_attr_is_input_with_attr_is_map
                        )

                    ops_defined_list.append(op_verify_str)
                    ops_defined_list.append(op_get_kernel_type_for_var_str)
                    ops_defined_list.append(parse_kernel_key_define_str)
                    ops_defined_list.append(infer_symbolic_shape_define_str)

                    # NOTE(chenxi67): skip the vjp definition when dialect_name is "cinn" or "onednn_op"
                    if dialect_name == "cinn" or dialect_name == "onednn_op":
                        pass
                    else:
                        ops_vjp_defined_list.append(op_vjp_str)

            if op_kernel_map is not None and len(op_kernel_map['func']) > 1:
                OP_TO_MULTI_KERNELS_MAP_ITEM = (
                    """{{"{op_name}", {{{sig_list}}}}}"""
                )
                OP_TO_MULTI_KERNELS_MAP_ITEM_SIG = """paddle::dialect::PdOpSig("{kernel_name}", {{{inputs}}}, {{{outputs}}})"""
                op_to_multi_kernels_sig_list = []
                if op_info.is_sparse_op:
                    SP_OP_TO_MULTI_KERNELS_MAP_ITEM = (
                        """{{"{op_name}", {{{sig_list}}}}}"""
                    )
                    SP_OP_TO_MULTI_KERNELS_MAP_ITEM_SIG = """paddle::dialect::PdOpSig("{kernel_name}", {{{inputs}}}, {{{outputs}}})"""
                    sp_op_to_multi_kernels_sig_list = []
                for kernel_func_name in op_kernel_map['func']:
                    inputs = op_kernel_map['dispatch'][kernel_func_name][0]
                    outputs = op_kernel_map['dispatch'][kernel_func_name][1]
                    inputs = '"' + '", "'.join(inputs) + '"'
                    outputs = '"' + '", "'.join(outputs) + '"'
                    if op_name[-1] == "_":
                        kernel_func_name = kernel_func_name + "_"
                    if op_info.is_sparse_op:
                        sp_op_to_multi_kernels_sig_list.append(
                            SP_OP_TO_MULTI_KERNELS_MAP_ITEM_SIG.format(
                                kernel_name=kernel_func_name,
                                inputs=inputs,
                                outputs=outputs,
                            )
                        )
                    else:
                        op_to_multi_kernels_sig_list.append(
                            OP_TO_MULTI_KERNELS_MAP_ITEM_SIG.format(
                                kernel_name=kernel_func_name,
                                inputs=inputs,
                                outputs=outputs,
                            )
                        )
                if op_info.is_sparse_op:
                    op_name += "_sp"
                    sp_op_to_multi_kernels_str = (
                        SP_OP_TO_MULTI_KERNELS_MAP_ITEM.format(
                            op_name=op_name,
                            sig_list=", ".join(sp_op_to_multi_kernels_sig_list),
                        )
                    )
                    sp_op_to_multi_kernels_list.append(
                        sp_op_to_multi_kernels_str
                    )
                else:
                    op_to_multi_kernels_str = (
                        OP_TO_MULTI_KERNELS_MAP_ITEM.format(
                            op_name=op_name,
                            sig_list=", ".join(op_to_multi_kernels_sig_list),
                        )
                    )
                    op_to_multi_kernels_list.append(op_to_multi_kernels_str)

    # GET_OP_LIST for cc
    op_namespaces_prev = ""
    for name in namespaces:
        op_namespaces_prev += name + "::"
    ops_name_with_namespace_list = []
    for name in ops_name_list:
        ops_name_with_namespace_list.append(op_namespaces_prev + name)
    op_list_str = GET_OP_LIST_TEMPLATE.format(
        ", ".join(ops_name_with_namespace_list)
    )

    # IR_DECLARE_EXPLICIT_TYPE_ID for h
    declare_type_id_str = ""
    for op in ops_name_with_namespace_list:
        declare_type_id_str += DECLARE_OP_TYPE_ID.format(op_name=op)

    # IR_DEFINE_EXPLICIT_TYPE_ID for cc
    define_type_id_str = ""
    for op in ops_name_with_namespace_list:
        define_type_id_str += DEFINE_OP_TYPE_ID.format(op_name=op)

    # Op Declare for h
    head_file_str = "".join(ops_declare_list)

    # OpDefine for cc
    source_file_str = "".join(ops_defined_list)  # Add op define

    # VJP for cc
    vjp_source_file_str = "".join(ops_vjp_defined_list)

    return (
        op_list_str,
        declare_type_id_str,
        define_type_id_str,
        head_file_str,
        source_file_str,
        op_to_multi_kernels_list,
        sp_op_to_multi_kernels_list,
        vjp_source_file_str,
    )


def OpGenerator(
    args: argparse.Namespace,
    op_yaml_files: list[str],
    op_compat_yaml_file: str,
    namespaces: list[str],
    dialect_name: str,
    op_def_h_file: str,
    op_info_file: str,
    op_def_cc_file: list[str],
    op_vjp_cc_file: str,
    op_cc_split_num: int,
    bwd_op_cc_split_num: int,
    onednn_yaml_file: str | None,
    ops_onednn_extra_yaml_file: str | None,
):
    # (1) Prepare: Delete existing old files: pd_op.h.tmp, pd_op.cc.tmp
    if os.path.exists(op_def_h_file):
        os.remove(op_def_h_file)
    for op_file in op_def_cc_file:
        if os.path.exists(op_file):
            os.remove(op_file)
    if (op_vjp_cc_file is not None) and (os.path.exists(op_vjp_cc_file)):
        os.remove(op_vjp_cc_file)

    # (2) parse yaml files
    op_compat_parser = OpCompatParser(op_compat_yaml_file)

    if dialect_name == "onednn_op":
        if onednn_yaml_file is None or ops_onednn_extra_yaml_file is None:
            raise ValueError(
                "onednn_op should provide onednn_yaml_file and ops_onednn_extra_yaml_file"
            )
        with open(ops_onednn_extra_yaml_file, "r") as f:
            ops_onednn_extra = yaml.safe_load(f)
            ops_onednn_extra_map = {}
            for op in ops_onednn_extra:
                op_name = op['op']
                item = {}
                item["is_onednn_only"] = False
                if 'extra_args' in op:
                    item["extra_args"] = parse_extra_args(
                        op_name, op['extra_args']
                    )
                    item["attrs"] = parse_extra_args(op_name, op['extra_args'])
                else:
                    item["extra_args"] = None
                    item["attrs"] = None
                if 'data_format_tensors' in op:
                    item["data_format_tensors"] = parse_data_format_tensors(
                        op_name, op['data_format_tensors']
                    )
                else:
                    item["data_format_tensors"] = None
                if 'dynamic_fallback' in op:
                    item["dynamic_fallback"] = op['dynamic_fallback']
                else:
                    item["dynamic_fallback"] = False
                ops_onednn_extra_map[op_name] = item
        op_yaml_files.insert(0, onednn_yaml_file)

    op_infos: list[dict[str, OpInfoParser]] = []
    all_op_info_items: dict[str, OpInfoParser] = {}
    new_op_def_cc_file = []
    first_file = True
    onednn_only_op_list = []
    for idx in range(len(op_yaml_files)):
        yaml_file = op_yaml_files[idx]
        op_yaml_items = []
        with open(yaml_file, "r") as f:
            ops = yaml.safe_load(f)
            op_yaml_items = op_yaml_items + ops
        op_info_items = {}
        for op in op_yaml_items:
            op_compat_item = None
            if dialect_name == "pd_op" or dialect_name == "onednn_op":
                op_compat_item = op_compat_parser.get_compat(op['name'])

            if (
                op_compat_item is None
                and op['name'].endswith(('_grad', '_grad_'))
                and 'forward' in op
            ):
                op_compat_item = op_compat_parser.get_compat(
                    op['forward']['name']
                )

            if (
                op_compat_item is not None
                and op_compat_item['op'] == "pow"
                and 'scalar' in op_compat_item
            ):
                op_compat_item = op_compat_item.pop('scalar')

            if (
                op_compat_item is not None
                and 'support_tensor' in op.keys()
                and op['support_tensor']
            ):
                (
                    scalar_item,
                    int_array_item,
                ) = op_compat_parser.parse_support_tensor(op)
                op_compat_item['scalar'] = scalar_item
                op_compat_item['int_array'] = int_array_item
            if dialect_name == "onednn_op":
                if first_file:
                    op["is_onednn_only"] = True
                    onednn_only_op_list.append('"' + op['name'] + '"')
                    if op['name'] in ops_onednn_extra_map:
                        onednn_item = ops_onednn_extra_map[op['name']]
                        op["is_onednn_only"] = onednn_item["is_onednn_only"]
                        op["extra_args"] = onednn_item["extra_args"]
                        op["data_format_tensors"] = onednn_item[
                            "data_format_tensors"
                        ]
                        op["dynamic_fallback"] = onednn_item["dynamic_fallback"]
                        if onednn_item["attrs"] is not None:
                            op["attrs"] = op["attrs"] + onednn_item["attrs"]
                        else:
                            op["attrs"] = op["attrs"]
                elif op['name'] in ops_onednn_extra_map:
                    onednn_item = ops_onednn_extra_map[op['name']]
                    op["is_onednn_only"] = onednn_item["is_onednn_only"]
                    op["extra_args"] = onednn_item["extra_args"]
                    op["data_format_tensors"] = onednn_item[
                        "data_format_tensors"
                    ]
                    op["dynamic_fallback"] = onednn_item["dynamic_fallback"]
                    if onednn_item["attrs"] is not None:
                        op["attrs"] = op["attrs"] + onednn_item["attrs"]
                else:
                    continue
            item = OpInfoParser(op, op_compat_item, yaml_file)
            key_suffix = '_sp' if item.is_sparse_op else ''
            op_info_items[op['name'] + key_suffix] = item
            all_op_info_items[op['name'] + key_suffix] = item

        if dialect_name != "onednn_op":
            cc_file = op_def_cc_file[idx]
            if (
                yaml_file.split('/')[-1] == "ops.parsed.yaml"
                and op_cc_split_num is not None
            ):
                split_op_info_items, split_cc_files = split_ops(
                    op_info_items, cc_file, op_cc_split_num
                )
                op_infos.extend(split_op_info_items)
                new_op_def_cc_file.extend(split_cc_files)
            elif (
                yaml_file.split('/')[-1] == "backward.parsed.yaml"
                and bwd_op_cc_split_num is not None
            ):
                split_op_info_items, split_cc_files = split_ops(
                    op_info_items, cc_file, bwd_op_cc_split_num
                )
                op_infos.extend(split_op_info_items)
                new_op_def_cc_file.extend(split_cc_files)
            else:
                op_infos.append(op_info_items)
                new_op_def_cc_file.append(cc_file)

        if first_file:
            first_file = False

    if dialect_name == "onednn_op":
        op_infos = [all_op_info_items]
        new_op_def_cc_file = op_def_cc_file
    # (3) auto code gen
    op_list_strs = []
    declare_type_id_strs = []
    define_type_id_strs = []
    head_file_strs = []
    source_file_strs = []
    op_to_multi_kernels_lists = []
    sp_op_to_multi_kernels_lists = []
    vjp_source_file_strs = []
    for items in op_infos:
        (
            op_list_str,
            declare_type_id_str,
            define_type_id_str,
            head_file_str,
            source_file_str,
            op_to_multi_kernels_list,
            sp_op_to_multi_kernels_list,
            vjp_source_file_str,
        ) = AutoCodeGen(
            args, items, all_op_info_items, namespaces, dialect_name
        )
        op_list_strs.append(op_list_str)
        declare_type_id_strs.append(declare_type_id_str)
        define_type_id_strs.append(define_type_id_str)
        head_file_strs.append(head_file_str)
        source_file_strs.append(source_file_str)
        sp_op_to_multi_kernels_lists = (
            sp_op_to_multi_kernels_lists + sp_op_to_multi_kernels_list
        )
        op_to_multi_kernels_lists = (
            op_to_multi_kernels_lists + op_to_multi_kernels_list
        )
        vjp_source_file_strs.append(vjp_source_file_str)

    # (4) write to files for pd_op.h.tmp (all yaml gen to pd_op.h)
    only_pd_op_header_files_str = ""
    if dialect_name == "pd_op":
        other_info = OP_TO_MULTI_KERNELS_MAP_H
        for name in reversed(namespaces):
            other_info = NAMESPACE_GUARD_TEMPLATE.format(
                namespace=name, input=other_info
            )  # Add namespaces
        only_pd_op_header_files_str = """
#include \"paddle/phi/common/data_type.h\"
#include \"paddle/fluid/pir/dialect/operator/interface/get_kernel_type_for_var.h\"
            """
    elif dialect_name == "onednn_op":
        other_info = ONEDNN_ONLY_OP_SET_H
        for name in reversed(namespaces):
            other_info = NAMESPACE_GUARD_TEMPLATE.format(
                namespace=name, input=other_info
            )  # Add namespaces
    else:
        other_info = ""

    head_file_str = "\n".join(head_file_strs)
    declare_type_id_str = "\n".join(declare_type_id_strs)
    for name in reversed(namespaces):
        head_file_str = NAMESPACE_GUARD_TEMPLATE.format(
            namespace=name, input=head_file_str
        )  # Add namespaces
    head_file_str = H_FILE_TEMPLATE.format(
        other_info=other_info,
        input=head_file_str,
        declare_type_id=declare_type_id_str,
        only_pd_op_header_files=only_pd_op_header_files_str,
    )
    with open(op_def_h_file, 'w') as f:
        f.write(head_file_str)

    # (5) write to files for pd_op_info.tmp
    if dialect_name == "pd_op":
        other_info_str = OP_TO_MULTI_KERNELS_MAPS.format(
            maps=", \r".join(op_to_multi_kernels_lists)
        )
        sp_other_info_str = SP_OP_TO_MULTI_KERNELS_MAPS.format(
            maps=", \r".join(sp_op_to_multi_kernels_lists)
        )
        other_info_str += sp_other_info_str
        for name in reversed(namespaces):
            other_info_str = NAMESPACE_GUARD_TEMPLATE.format(
                namespace=name, input=other_info_str
            )  # Add namespaces
    elif dialect_name == "onednn_op":
        other_info_str = ONEDNN_ONLY_OP_SET.format(
            maps=", \r".join(onednn_only_op_list)
        )
        for name in reversed(namespaces):
            other_info_str = NAMESPACE_GUARD_TEMPLATE.format(
                namespace=name, input=other_info_str
            )  # Add namespaces
    else:
        other_info_str = ""

    if op_info_file is not None:
        if sys.platform == "win32":
            n = len(op_list_strs) // 4
            first_part_op_info = op_list_strs[:n]
            second_part_op_info = op_list_strs[n : 2 * n]
            third_part_op_info = op_list_strs[2 * n : 3 * n]
            fourth_part_op_info = op_list_strs[3 * n :]
            CC_OP_INFO_FILE_TEMPLATE = (
                CC_OP_INFO_FILE_TEMPLATE_WIN_PART1
                + CC_OP_INFO_FILE_TEMPLATE_PART2
            )
            op_info_str = CC_OP_INFO_FILE_TEMPLATE.format(
                op_declare_first_part=",".join(first_part_op_info).replace(
                    "\n", ""
                ),
                op_declare_second_part=",".join(second_part_op_info).replace(
                    "\n", ""
                ),
                op_declare_third_part=",".join(third_part_op_info).replace(
                    "\n", ""
                ),
                op_declare_fourth_part=",".join(fourth_part_op_info).replace(
                    "\n", ""
                ),
                other_info=other_info_str,
                h_file=op_def_h_file[:-4],
            )
        else:
            CC_OP_INFO_FILE_TEMPLATE = (
                CC_OP_INFO_FILE_TEMPLATE_PART1 + CC_OP_INFO_FILE_TEMPLATE_PART2
            )
            op_info_str = CC_OP_INFO_FILE_TEMPLATE.format(
                op_declare=",".join(op_list_strs).replace("\n", ""),
                other_info=other_info_str,
                h_file=op_def_h_file[:-4],
            )

        with open(op_info_file, 'w') as f:
            f.write(op_info_str)

    # (6) write to files for xx_op.cc.tmp
    for id in range(len(new_op_def_cc_file)):
        source_file_str = source_file_strs[id]
        for name in reversed(namespaces):
            source_file_str = NAMESPACE_GUARD_TEMPLATE.format(
                namespace=name, input=source_file_str
            )  # Add namespaces

        if dialect_name == "onednn_op":
            op_def_h_file_tmp = (
                'paddle/fluid/pir/dialect/operator/ir/pd_op.h"\n#include "'
                + op_def_h_file
            )
        else:
            op_def_h_file_tmp = op_def_h_file

        source_file_str = CC_FILE_TEMPLATE.format(
            h_file=op_def_h_file_tmp[:-4],
            input=source_file_str,
            define_type_id=define_type_id_strs[id],
        )
        with open(new_op_def_cc_file[id], 'w') as f:
            f.write(source_file_str)

    # (7) write to files for xx_vjp_op.cc.tmp
    # NOTE(Aurelius84): op_gen.py is called multiple times,
    # and vjp is only available for pd dialect.
    vjp_source_file_str = "\n".join(vjp_source_file_strs)
    vjp_source_file_str = VJP_CC_FILE_TEMPLATE.format(input=vjp_source_file_str)
    if (
        dialect_name != 'cinn'
        and dialect_name != 'onednn_op'
        and op_vjp_cc_file
    ):
        with open(op_vjp_cc_file, 'w') as f:
            f.write(vjp_source_file_str)


def strtobool(val):
    """Convert a textual truth value into a ``bool``.

    Matching is case-insensitive and ignores leading/trailing whitespace.

    Args:
        val: The string to interpret, e.g. ``"yes"``, ``"TRUE"``, ``"0"``.

    Returns:
        ``True`` for ``y``, ``yes``, ``t``, ``true``, ``on`` or ``1``;
        ``False`` for ``n``, ``no``, ``f``, ``false``, ``off`` or ``0``.

    Raises:
        ValueError: If ``val`` is not one of the spellings listed above.
    """
    # Normalize into a separate name so the error message below can still
    # show exactly what the caller passed in.
    normalized = val.strip().lower()
    if normalized in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if normalized in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    raise ValueError(f"Invalid truth value {val!r}")


# =====================================
# Script parameter parsing
# =====================================
def ParseArguments():
    """Build the command-line parser for this script and parse ``sys.argv``.

    All options are optional, so any option not supplied is ``None`` in the
    result.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    parser = argparse.ArgumentParser(
        description='Generate Dialect OP Definition Files By Yaml'
    )
    # (flag, converter) pairs, registered in this order.
    option_specs = (
        ('--op_yaml_files', str),
        ('--op_compat_yaml_file', str),
        ('--namespaces', str),
        ('--dialect_name', str),
        ('--op_def_h_file', str),
        ('--op_info_file', str),
        ('--op_def_cc_file', str),
        ('--op_vjp_cc_file', str),
        ('--op_cc_split_num', int),
        ('--bwd_op_cc_split_num', int),
        ('--onednn_yaml_file', str),
        ('--ops_onednn_extra_yaml_file', str),
        ('--with_distributed', strtobool),
    )
    for flag, converter in option_specs:
        parser.add_argument(flag, type=converter)
    return parser.parse_args()


# =====================================
# Main
# =====================================
if __name__ == "__main__":
    # Read every option from the command line.
    args = ParseArguments()

    # Options supplied as comma-separated strings are expanded into lists;
    # an absent --namespaces yields an empty list.
    op_yaml_files = args.op_yaml_files.split(",")
    namespaces = [] if args.namespaces is None else args.namespaces.split(",")
    op_def_cc_files = args.op_def_cc_file.split(",")

    # The remaining options are forwarded exactly as parsed.
    op_compat_yaml_file = args.op_compat_yaml_file
    dialect_name = args.dialect_name
    op_def_h_file = args.op_def_h_file
    op_info_file = args.op_info_file
    op_vjp_cc_file = args.op_vjp_cc_file
    op_cc_split_num = args.op_cc_split_num
    bwd_op_cc_split_num = args.bwd_op_cc_split_num
    onednn_yaml_file = args.onednn_yaml_file
    ops_onednn_extra_yaml_file = args.ops_onednn_extra_yaml_file

    # Run the generator; all arguments are passed positionally.
    OpGenerator(
        args,
        op_yaml_files,
        op_compat_yaml_file,
        namespaces,
        dialect_name,
        op_def_h_file,
        op_info_file,
        op_def_cc_files,
        op_vjp_cc_file,
        op_cc_split_num,
        bwd_op_cc_split_num,
        onednn_yaml_file,
        ops_onednn_extra_yaml_file,
    )
